#include "systemc.h"

#include "prims/comp_base.h"
#include "prims/comp_prims.h"
#include "utils/memory_utils.h"
#include "utils/prim_utils.h"
#include "utils/system_utils.h"

/**
 * Write a human-readable summary of this primitive to stdout.
 *
 * Emits four lines: a tag line, the B/T/C shape parameters, the three size
 * fields, and the output/input offsets.
 *
 * @param prefix Text written at the start of every emitted line.
 */
void Attention_f::print_self(string prefix) {
    // One chained expression; each endl still flushes exactly where the
    // line-by-line form did, so the emitted stream is unchanged.
    cout << prefix << "<attention_forward>\n"
         << prefix << "\tB: " << B << ", T: " << T << ", C: " << C << endl
         << prefix << "\tout_size: " << out_size << " , inp_size: " << inp_size
         << ", previous_inp_size: " << p_inp_size << endl
         << prefix << "\toutput_offset: " << out_offset
         << ", input_offset: " << inp_offset << endl;
}

/**
 * Derive all size and offset fields from the shape parameters.
 *
 * Reads B, T, C, NH, R, datatype and inp_offset; writes out_size,
 * p_inp_size, inp_size, dram_inp_size, dram_out_size, dram_data_size,
 * data_byte (only for INT8 / FP16), prea_offset and a_offset.
 *
 * The real output element count A obeys A * (1 + 2 / R) = C per position,
 * i.e. A = C * R / (R + 2). It is evaluated in that rearranged form because
 * the literal form `1 + 2 / R` truncates in integer arithmetic and collapses
 * to 1 for every R > 2, which would report the whole input as output.
 * For R == 1 and R == 2 both forms give the same value.
 */
void Attention_f::initialize() {
    // R is a divisor ratio; a non-positive value has no meaning here and
    // R == 0 used to end in an integer division by zero.
    assert(R > 0 && "attention_forward: R must be positive");

    // Widen before multiplying by R so the extra factor cannot overflow int
    // where the plain B * T * C product would still have fitted.
    const long long inp_elems = static_cast<long long>(B) * T * C;
    out_size = static_cast<int>(inp_elems * R / (R + 2));

    p_inp_size = B * T * C;
    // Input plus the two B * NH * T * T score buffers (preatt and att).
    inp_size = B * T * C + 2 * B * NH * T * T;

    // Sizes rounded up to whole DRAM_ALIGN units.
    // NOTE(review): the input uses 3 * C here while p_inp_size uses C —
    // confirm which of the two conventions is intended.
    dram_inp_size = (B * T * 3 * C + (DRAM_ALIGN - 1)) / DRAM_ALIGN;
    dram_out_size = (out_size + (DRAM_ALIGN - 1)) / DRAM_ALIGN;
    dram_data_size = 0;

    // Bytes per element; data_byte keeps its previous value for any other
    // datatype.
    if (datatype == INT8)
        data_byte = 1;
    else if (datatype == FP16)
        data_byte = 2;

    // preatt is placed right after the input, att right after preatt.
    prea_offset = B * T * C + inp_offset;
    a_offset = B * NH * T * T + prea_offset;
}

/**
 * Configure this primitive from its JSON description.
 *
 * Keys read:
 *   - "B", "T", "C", "NH": shape parameters (required).
 *   - "R": optional, defaults to 1 when the key is absent.
 *   - "dram_address": optional, forwarded to parse_address().
 *   - "sram_address": optional, forwarded to parse_sram_label().
 *
 * Address rules implemented here:
 *   - out_offset must be set once the DRAM addresses are parsed; a value of
 *     -1 is rejected with an assertion.
 *   - inp_offset, when it is -1, is derived from out_offset as
 *     (out_offset * 1024 - B * T * C) / 1024.
 *   - No data offset is handled here; attention takes no extra weights.
 *
 * NOTE(review): initialize() runs before the addresses are parsed, so the
 * prea_offset / a_offset it computes are based on whatever inp_offset held
 * beforehand — confirm that this is acceptable for callers that do not go
 * through deserialize().
 *
 * @param j JSON object describing the primitive.
 */
void Attention_f::parse_json(json j) {
    B = find_var(j["B"]);
    T = find_var(j["T"]);
    C = find_var(j["C"]);
    NH = find_var(j["NH"]);

    // R is documented as defaulting to 1; only look it up when present
    // instead of indexing a missing key.
    if (j.contains("R"))
        R = find_var(j["R"]);
    else
        R = 1;

    // initialize() divides by an expression built from R.
    assert(R > 0 && "attention_forward: R must be positive");

    initialize();

    if (j.contains("dram_address"))
        parse_address(j["dram_address"]);

    // Validate out_offset first: the inp_offset fallback below is computed
    // from it and would otherwise be derived from the -1 sentinel.
    if (out_offset == -1)
        assert(0 && "attention_forward: out_offset not set");

    if (inp_offset == -1)
        inp_offset = (out_offset * 1024 - B * T * C) / 1024;

    // Diagnostic output of the resolved offsets.
    cout << "\033[1;33m" << "Attention_f" << "\033[0m" << endl;
    cout << "inp_offset: " << inp_offset << endl;
    cout << "out_offset: " << out_offset << endl;

    if (j.contains("sram_address"))
        parse_sram_label(j["sram_address"]);
}

/**
 * Estimate the SRAM footprint of this primitive in bytes.
 *
 * Three buffers are sized separately — the B * T * 3 * C input, the
 * B * NH * T * T attention buffer and the out_size output — each converted
 * to a count of SRAM words of get_sram_bitwidth(cid) bits via
 * ceiling_division (presumably rounding up — confirm against its
 * definition). The word counts are summed and converted back to bytes.
 *
 * @param datatype Not read by this function; the element width comes from
 *                 the data_byte member instead.
 * @param cid      Core id passed to get_sram_bitwidth().
 * @return Total size in bytes.
 */
int Attention_f::sram_utilization(DATATYPE datatype, int cid) {
    const auto word_bits = get_sram_bitwidth(cid);

    const int input_words =
        ceiling_division(B * T * 3 * C * data_byte * 8, word_bits);
    const int att_words =
        ceiling_division(B * NH * T * T * data_byte * 8, word_bits);
    const int output_words =
        ceiling_division(out_size * data_byte * 8, word_bits);

    // words -> bits -> bytes
    return (input_words + att_words + output_words) * word_bits / 8;
}

/**
 * Restore this primitive from a 128-bit instruction word.
 *
 * Bit layout read here (inclusive ranges):
 *   [23:8]    inp_offset, scaled by 1024 after reading
 *   [39:24]   out_offset, scaled by 1024 after reading
 *   [55:40]   B
 *   [71:56]   T
 *   [87:72]   C
 *   [103:88]  NH
 *   [105:104] datatype
 *   [113:106] R
 *
 * initialize() is called afterwards to recompute the derived fields.
 *
 * @param buffer Encoded instruction word.
 */
void Attention_f::deserialize(sc_bv<128> buffer) {
    // Unsigned value held in bits [hi:lo] of the word.
    auto field = [&buffer](int hi, int lo) {
        return buffer.range(hi, lo).to_uint64();
    };

    // Offsets are stored in units of 1024 and expanded on load.
    inp_offset = field(23, 8) * 1024;
    out_offset = field(39, 24) * 1024;

    B = field(55, 40);
    T = field(71, 56);
    C = field(87, 72);
    NH = field(103, 88);
    datatype = DATATYPE(field(105, 104));
    R = field(113, 106);

    initialize();
}

/**
 * Pack this primitive into a 128-bit instruction word.
 *
 * Bit layout written here (inclusive ranges):
 *   [7:0]     ATTENTION_F_TYPE tag
 *   [23:8]    inp_offset (16 bits, written as-is)
 *   [39:24]   out_offset (16 bits, written as-is)
 *   [55:40]   B
 *   [71:56]   T
 *   [87:72]   C
 *   [103:88]  NH
 *   [105:104] datatype
 *   [113:106] R
 *
 * Each field is narrowed to its bit width; values that do not fit lose
 * their upper bits.
 *
 * @return The encoded word.
 */
sc_bv<128> Attention_f::serialize() {
    sc_bv<128> word;

    word.range(7, 0) = sc_bv<8>(ATTENTION_F_TYPE);
    word.range(23, 8) = sc_bv<16>(inp_offset);
    word.range(39, 24) = sc_bv<16>(out_offset);
    word.range(55, 40) = sc_bv<16>(B);
    word.range(71, 56) = sc_bv<16>(T);
    word.range(87, 72) = sc_bv<16>(C);
    word.range(103, 88) = sc_bv<16>(NH);
    word.range(105, 104) = sc_bv<2>(datatype);
    word.range(113, 106) = sc_bv<8>(R);

    return word;
}

/**
 * Simulate one execution of the attention-forward primitive.
 *
 * Steps, in order:
 *   1. Strip a leading '_' from the first input label (this marks the input
 *      as reused, so its SRAM label is kept afterwards).
 *   2. Load the input through check_input_data().
 *   3. With USE_SRAM == 1: write the preatt buffer, read it back, write the
 *      att buffer, read it back, then drop the input label unless reused.
 *   4. Hand the operation count and output size to write_output_data().
 *
 * @param context Per-core task context forwarded to the memory helpers.
 * @return overlap_time as filled in by write_output_data(), narrowed to int.
 */
int Attention_f::task_core(TaskCoreContext &context) {
    // Time accumulators updated by the helper calls below.
    u_int64_t dram_time = 0;
    u_int64_t overlap_time = 0;

    // Element counts of the buffers involved.
    vector<int> data_size_input = {B * T * C}; // QKV input
    int data_size_preatt = B * NH * T * T;     // preatt
    int data_size_att = B * NH * T * T;        // att

    // Real output A obeys A * (1 + 2 / R) = C, i.e. A = C * R / (R + 2).
    // Evaluated in the rearranged form: `2 / R` truncates to 0 for every
    // R > 2 in integer arithmetic, which would report the full input size.
    // The product is widened so the extra factor R cannot overflow int.
    int data_size_out =
        static_cast<int>(static_cast<long long>(B) * T * C * R / (R + 2));

    // DRAM addresses. The tile base is fixed at 0 (a per-core base of
    // cid * dataset_words_per_tile is not applied).
    u_int64_t dram_addr_tile = 0;
    u_int64_t inp_global_addr = dram_addr_tile + inp_offset * data_byte;
    u_int64_t out_global_addr = dram_addr_tile + out_offset * data_byte;

    // A leading '_' on the input label requests data reuse; remove the
    // marker so the remaining text is the actual label.
    bool input_reuse = false;
    if (datapass_label.indata[0][0] == '_') {
        input_reuse = true;
        datapass_label.indata[0] = datapass_label.indata[0].substr(1);
    }

    // Load the input data.
    check_input_data(context, dram_time, inp_global_addr, data_size_input);
    BETTER_PRINT(dram_time);

#if USE_SRAM == 1
    {
        // temp_sram_addr is handed to the write helper and printed again
        // afterwards; temp_sram_addr_prior holds the value it had before
        // the most recent write and is what the following read uses.
        int temp_sram_addr = 0;
        int temp_sram_addr_prior = temp_sram_addr;

        // Write the preatt intermediate result.
        std::cout << "attention_forward sram_write_back_temp: temp_sram_addr: "
                  << temp_sram_addr << std::endl;
        sram_write_back_temp(context, data_byte * data_size_preatt,
                             temp_sram_addr, dram_time);
        std::cout
            << "attention_forward sram_read_generic_temp: temp_sram_addr: "
            << temp_sram_addr << std::endl;

        // Read preatt back, apply the exponential, and write att.
        sram_read_generic_temp(context, data_byte * data_size_preatt,
                               temp_sram_addr_prior, dram_time);
        temp_sram_addr_prior = temp_sram_addr;
        std::cout << "attention_forward sram_write_back_temp: temp_sram_addr: "
                  << temp_sram_addr << std::endl;
        sram_write_back_temp(context, data_byte * data_size_att, temp_sram_addr,
                             dram_time);

        // Read att back.
        std::cout
            << "attention_forward sram_read_generic_temp: temp_sram_addr: "
            << temp_sram_addr << std::endl;
        sram_read_generic_temp(context, data_byte * data_size_att,
                               temp_sram_addr_prior, dram_time);

        // Drop the input label unless the input is to be reused.
        if (!input_reuse)
            sram_pos_locator->deletePair(datapass_label.indata[0]);

        BETTER_PRINT(dram_time);
    }
#endif

    // Compute the overlap and write back the output. The second argument is
    // presumably the operation count of the attention computation — confirm
    // against write_output_data().
    write_output_data(context, (uint64_t)B * NH * T * (T - 1) / 2 * (4 * C / NH + 5), 0,
                      dram_time, overlap_time, data_size_out, out_global_addr);

    BETTER_PRINT(overlap_time);

    return overlap_time;
}

/**
 * Context-free task entry point. Performs no work for this primitive.
 *
 * @return Always 0.
 */
int Attention_f::task() {
    return 0;
}